{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get a 5s excerpt of the first test file\n",
    "from pyannote.database import get_protocol, FileFinder\n",
    "from pyannote.audio.core.io import Audio\n",
    "from pyannote.core import Segment\n",
    "import torch\n",
    "\n",
    "# FileFinder is passed as the \"audio\" preprocessor of the debug protocol\n",
    "protocol = get_protocol(\n",
    "    'Debug.SpeakerDiarization.Debug',\n",
    "    preprocessors={\"audio\": FileFinder()},\n",
    ")\n",
    "\n",
    "# first file yielded by the protocol's test subset\n",
    "file = next(protocol.test())\n",
    "\n",
    "# crop Segment(5, 10) out of that file with a 16 kHz mono loader\n",
    "audio = Audio(sample_rate=16000, mono=True)\n",
    "waveform, sample_rate = audio.crop(file, Segment(5, 10))\n",
    "\n",
    "# prepend a leading dimension of size 1\n",
    "waveforms = torch.tensor(waveform).unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# listen to the excerpt\n",
    "from IPython.display import Audio as Play\n",
    "\n",
    "Play(\n",
    "    waveforms.squeeze(),\n",
    "    rate=sample_rate,\n",
    "    normalize=False,\n",
    "    autoplay=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a model whose forward pass hands its input straight back\n",
    "from pyannote.audio.core.model import Model\n",
    "\n",
    "\n",
    "class Passthrough(Model):\n",
    "    \"\"\"Model whose forward pass returns its input as is.\"\"\"\n",
    "\n",
    "    def forward(self, waveforms):\n",
    "        \"\"\"Return `waveforms` unchanged.\"\"\"\n",
    "        return waveforms\n",
    "\n",
    "\n",
    "identity = Passthrough()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# listen to what the \"identity\" model returns for the excerpt\n",
    "Play(\n",
    "    identity(waveforms).squeeze(),\n",
    "    rate=sample_rate,\n",
    "    normalize=False,\n",
    "    autoplay=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# register one torch_audiomentations waveform transform on the model input\n",
    "from pyannote.audio.augmentation.registry import register_augmentation\n",
    "from torch_audiomentations import Gain\n",
    "\n",
    "# NOTE(review): p=0.5 presumably means the gain is only applied half of the time - confirm\n",
    "gain = Gain(min_gain_in_db=-15.0, max_gain_in_db=5.0, p=0.5)\n",
    "register_augmentation(gain, identity, when='input')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# listen to what the model returns now that the augmentation is registered\n",
    "Play(\n",
    "    identity(waveforms).squeeze(),\n",
    "    rate=sample_rate,\n",
    "    normalize=False,\n",
    "    autoplay=True,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
